import java.util.Arrays;

/**
 * Solution for counting square submatrices made entirely of ones.
 */
public class Leet1227 {
    /**
     * Counts the square submatrices whose every cell equals {@code 1}.
     *
     * <p>{@code side[i][j]} holds the side length of the largest all-ones square whose
     * bottom-right corner is cell {@code (i - 1, j - 1)} of {@code matrix}. Such a square of
     * side {@code k} contains exactly {@code k} distinct squares sharing that corner (sides
     * {@code 1..k}), so the answer is the sum of all {@code side} values. Single pass,
     * O(rows * cols) time.
     *
     * @param matrix a rectangular grid; cells equal to {@code 1} count as ones and every other
     *     value counts as a zero. The array is not modified. May be {@code null} or empty.
     * @return the number of all-ones square submatrices, or {@code 0} for a null or empty grid
     */
    public int countSquares(int[][] matrix) {
        if (matrix == null || matrix.length == 0 || matrix[0].length == 0) {
            return 0;
        }
        int rows = matrix.length;
        int cols = matrix[0].length;
        // One extra row/column of zeros so the recurrence never needs bounds checks.
        int[][] side = new int[rows + 1][cols + 1];
        int total = 0;
        for (int i = 1; i <= rows; i++) {
            for (int j = 1; j <= cols; j++) {
                if (matrix[i - 1][j - 1] == 1) {
                    // A k-square ends here iff (k-1)-squares end above, left, and diagonally.
                    int smallest = Math.min(side[i - 1][j - 1], Math.min(side[i - 1][j], side[i][j - 1]));
                    side[i][j] = smallest + 1;
                    total += side[i][j];
                }
            }
        }
        return total;
    }
}
